################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################

import ast
import ctypes
import inspect
import textwrap
from typing import Generic, TypeVar

from cuda import cuda, cudart
import numpy as np
from treelib import Tree

from cutlass.backend.epilogue import (
    AccumulatorOp,
    BinaryOp,
    ColumnBroadcastOp,
    ColumnReductionOp,
    RowBroadcastOp,
    RowReductionOp,
    TensorInputOp,
    TensorOutputOp,
    UnaryOp,
)
from cutlass.backend.frontend import NumpyFrontend
from cutlass.backend.utils.software import SubstituteTemplate
import cutlass.backend as backend

################################################################################
# Type annotation for input arguments
################################################################################

# Type variables used to parameterize NDArray annotations on the arguments
# of user-written epilogue functions (consumed as ast.Constant annotations
# by EpilogueAST.visit_FunctionDef below).
Ttype = TypeVar("Ttype")
Dtype = TypeVar("Dtype")


class NDArray(np.ndarray, Generic[Ttype, Dtype]):
    """ndarray subclass used purely as an annotation carrier; adds no behavior."""
    pass


################################################################################
# Operations
################################################################################

# Maps a Python AST binary-operator type to the op-name suffix used to look
# up the corresponding "Vector<Op>" functor on cutlass.backend (see BinOpNode).
operators = {
    ast.Add: "Add",
    ast.Div: "Div",
    ast.Eq: "Equal",
    ast.Mult: "Mult",
}


################################################################################
# AST Node abstractions
################################################################################
class UnaryNode:
    """Visitor-tree node for an elementwise unary epilogue op (UnaryOp).

    Created either from an ``ast.Call`` in the epilogue source, or from a
    :class:`BinOpNode` that ``pass_binary_2_unary`` lowers when one operand
    is a scalar.
    """

    # Monotonic counter giving each instance a unique id/tag.
    cnt = 0

    def __init__(
        self,
        element_accumulator,
        element_compute,
        elements_per_access,
        node,
        args,
    ) -> None:
        # Resolve the op name from whichever node kind we were handed.
        if isinstance(node, BinOpNode):
            self.op = node.op
        elif isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name):
                self.op = node.func.id
            elif isinstance(node.func, ast.Attribute):
                self.op = node.func.value.id
            else:
                raise TypeError
        else:
            raise TypeError
        self.tag = "Unary" + self.op + str(UnaryNode.cnt)
        self.id = self.op + str(UnaryNode.cnt)
        self.args = args
        UnaryNode.cnt += 1

        self.type = "tensor"

        # Look up the concrete functor on cutlass.backend by op name.
        self.epilogue_op = getattr(backend, self.op)(element_compute)

        # data types
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        """Construct the UnaryOp epilogue node wrapping the child visitors."""
        self.epilogue_node = UnaryOp(
            self.element_accumulator,
            self.element_compute,
            self.elements_per_access,
            *visitors,
            self.epilogue_op,
        )

    def get_argument(self, visitor_args, kwargs):
        """Build the runtime argument struct.

        Each entry in ``self.args`` is either the name of a user-supplied
        kwarg or a literal constant captured from the AST; constants fall
        through the kwargs lookup and are used directly.
        """
        epilogue_ops = []
        for arg in self.args:
            try:
                epilogue_ops.append(kwargs[arg])
            except (KeyError, TypeError):
                # Not a kwarg name (KeyError) or not hashable (TypeError):
                # treat it as a direct constant argument.
                epilogue_ops.append(arg)
        self.argument = self.epilogue_node.argument_type(
            self.epilogue_op.argument_type(*epilogue_ops),
            *visitor_args,
        )


class BinOpNode:
    """Visitor-tree node for an elementwise binary epilogue op (BinaryOp).

    Created from an ``ast.BinOp`` node; the operator type is translated to
    a backend ``Vector<Op>`` functor via the module-level ``operators`` map.
    """

    # Monotonic counter giving each instance a unique id/tag.
    cnt = 0

    def __init__(
        self,
        element_accumulator,
        element_compute,
        elements_per_access,
        node,
    ) -> None:
        op_name = operators[type(node.op)]
        serial = BinOpNode.cnt
        BinOpNode.cnt += 1

        self.op = op_name
        self.tag = "Binary" + op_name + str(serial)
        self.id = op_name + str(serial)
        self.args = None
        self.type = "tensor"

        # Resolve e.g. backend.VectorMult and instantiate it for the compute type.
        self.epilogue_op = getattr(backend, "Vector" + op_name)(element_compute)

        # data types
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        """Construct the BinaryOp epilogue node wrapping the child visitors."""
        self.epilogue_node = BinaryOp(
            self.element_accumulator,
            self.element_compute,
            self.elements_per_access,
            *visitors,
            self.epilogue_op,
        )

    def get_argument(self, visitor_args, kwargs):
        """Build the runtime argument struct from the child visitor arguments."""
        op_argument = self.epilogue_op.argument_type(self.args)
        self.argument = self.epilogue_node.argument_type(op_argument, *visitor_args)


class NameNode:
    """Base visitor-tree node created from a Python AST name.

    Accepts either an ``ast.Name`` (reads ``node.id``) or an assignment-like
    node such as ``ast.Assign`` (reads the first target's ``id``).
    """

    def __init__(self, node) -> None:
        try:
            self.id = node.id
        except AttributeError:
            # Assignment nodes carry the name on their first target.
            self.id = node.targets[0].id
        self.tag = self.id


class ScalarInputNode(NameNode):
    """Visitor-tree node for a scalar input argument.

    Scalars are not materialized as epilogue visitor ops; they are lowered
    away by ``pass_binary_2_unary`` and passed as functor arguments.
    """

    def __init__(self, node) -> None:
        super().__init__(node)
        # Prefix the display tag so scalar operands stand out in the tree.
        self.tag = "Scalar:" + self.tag
        self.type = "scalar"


class AccumulatorNode(NameNode):
    """Visitor-tree node for the GEMM accumulator (AccumulatorOp)."""

    def __init__(
        self,
        element_accumulator,
        elements_per_access,
        node,
    ) -> None:
        super().__init__(node)
        self.tag = f"Accum:{self.tag}"
        self.type = "tensor"
        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        """Construct the AccumulatorOp node; it is a leaf and takes no visitors."""
        self.epilogue_node = AccumulatorOp(
            self.element_accumulator, self.elements_per_access
        )

    def get_argument(self, visitor_args, kwargs):
        # The accumulator carries no runtime parameters.
        self.argument = self.epilogue_node.argument_type()


class TensorInputNode(NameNode):
    """Visitor-tree node loading an input tensor (TensorInputOp)."""

    def __init__(self, element_accumulator, node) -> None:
        super().__init__(node)
        self.tag = f"TensorInput:{self.tag}"
        self.type = "tensor"
        self.element_accumulator = element_accumulator

    def get_epilogue_node(self, *args):
        """Construct the TensorInputOp node; leaf, so visitors are ignored."""
        self.epilogue_node = TensorInputOp(self.element_accumulator)

    def get_argument(self, visitor_args, kwargs):
        """Pack device pointer, leading dimension (N), and batch stride (M * N)."""
        problem_size = kwargs["problem_size"]
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            problem_size[1],
            problem_size[0] * problem_size[1],
        )


class RowBroadcastNode(NameNode):
    """Visitor-tree node broadcasting a row vector (RowBroadcastOp)."""

    def __init__(
        self,
        element_accumulator,
        element_fragment,
        node,
    ) -> None:
        super().__init__(node)
        self.tag = f"RowBroadcast:{self.tag}"
        self.type = "tensor"
        self.element_accumulator = element_accumulator
        self.element_fragment = element_fragment

    def get_epilogue_node(self, *args):
        """Construct the RowBroadcastOp node; leaf, so visitors are ignored."""
        self.epilogue_node = RowBroadcastOp(
            self.element_accumulator, self.element_fragment
        )

    def get_argument(self, visitor_args, kwargs):
        # A row vector spans the N extent (problem_size[1]).
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            kwargs["problem_size"][1],
        )


class ColumnBroadcastNode(NameNode):
    """Visitor-tree node broadcasting a column vector (ColumnBroadcastOp)."""

    def __init__(
        self,
        element_accumulator,
        element_fragment,
        node,
    ) -> None:
        super().__init__(node)
        self.tag = f"ColumnBroadcast:{self.tag}"
        self.type = "tensor"
        self.element_accumulator = element_accumulator
        self.element_fragment = element_fragment

    def get_epilogue_node(self, *args):
        """Construct the ColumnBroadcastOp node; leaf, so visitors are ignored."""
        self.epilogue_node = ColumnBroadcastOp(
            self.element_accumulator, self.element_fragment
        )

    def get_argument(self, visitor_args, kwargs):
        # A column vector spans the M extent (problem_size[0]).
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            kwargs["problem_size"][0],
        )


class TensorOutputNode(NameNode):
    """Visitor-tree node storing a visitor's result to a tensor (TensorOutputOp)."""

    def __init__(self, element_accumulator, node) -> None:
        super().__init__(node)
        self.tag = f"TensorOutput:{self.tag}"
        self.type = "tensor"
        self.element_accumulator = element_accumulator

    def get_epilogue_node(self, visitors):
        """Construct the TensorOutputOp node wrapping the child visitors."""
        self.epilogue_node = TensorOutputOp(self.element_accumulator, *visitors)

    def get_argument(self, visitor_args, kwargs):
        """Pack pointer, ldm (N), child arguments, and batch stride (M * N)."""
        rows, cols = kwargs["problem_size"][0], kwargs["problem_size"][1]
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            cols,
            *visitor_args,
            rows * cols,
        )


class RowReductionNode:
    """Visitor-tree node reducing along rows (RowReductionOp).

    Unlike the NameNode subclasses, this node is injected by
    ``pass_inject_reduction`` rather than parsed from the AST.
    """

    def __init__(
        self,
        element_accumulator,
        element_reduction,
        element_reduction_accumulator,
        id,
        factor,
    ) -> None:
        self.id = id
        self.tag = "RowReduction:" + self.id
        self.type = "tensor"
        # data types
        self.element_accumulator = element_accumulator
        self.element_reduction = element_reduction
        self.element_reduction_accumulator = element_reduction_accumulator
        # reduction granularity (threadblock column extent at injection time)
        self.factor = factor

    def get_epilogue_node(self, visitors):
        """Construct the RowReductionOp node wrapping the child visitors."""
        self.epilogue_node = RowReductionOp(
            self.element_accumulator,
            self.element_reduction,
            self.element_reduction_accumulator,
            *visitors,
        )

    def get_batch_stride(self, problem_size):
        """Rows times the ceil-divided column extent."""
        column_chunks = (problem_size[1] + self.factor - 1) // self.factor
        return problem_size[0] * column_chunks

    def get_argument(self, visitor_args, kwargs):
        """Pack pointer, child arguments, and the computed batch stride."""
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            *visitor_args,
            self.get_batch_stride(kwargs["problem_size"]),
        )


class ColumnReductionNode:
    """Visitor-tree node reducing along columns (ColumnReductionOp).

    Unlike the NameNode subclasses, this node is injected by
    ``pass_inject_reduction`` rather than parsed from the AST.
    """

    def __init__(
        self,
        element_accumulator,
        element_reduction,
        element_reduction_accumulator,
        id,
        factor,
    ) -> None:
        self.id = id
        self.tag = "ColumnReduction:" + self.id
        self.type = "tensor"
        # data types
        self.element_accumulator = element_accumulator
        self.element_reduction = element_reduction
        self.element_reduction_accumulator = element_reduction_accumulator
        # reduction granularity (threadblock row extent at injection time)
        self.factor = factor

    def get_epilogue_node(self, visitors):
        """Construct the ColumnReductionOp node wrapping the child visitors."""
        self.epilogue_node = ColumnReductionOp(
            self.element_accumulator,
            self.element_reduction,
            self.element_reduction_accumulator,
            *visitors,
        )

    def get_batch_stride(self, problem_size):
        """Columns times the ceil-divided row extent."""
        row_chunks = (problem_size[0] + self.factor - 1) // self.factor
        return problem_size[1] * row_chunks

    def get_argument(self, visitor_args, kwargs):
        """Pack pointer, child arguments, and the computed batch stride."""
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            *visitor_args,
            self.get_batch_stride(kwargs["problem_size"]),
        )


################################################################################
# Epilogue parser function
################################################################################
class EpilogueAST(ast.NodeVisitor):
    """Parse a Python epilogue functor's ``__call__`` into a visitor tree.

    The constructor parses ``epilogue.__call__``'s source with :mod:`ast`
    and walks it, building ``self.epilogue_tree`` (a ``treelib.Tree``) whose
    node payloads are the *Node classes defined above.  The ``pass_*``
    methods then lower the tree (scalar binary ops to unary ops, reduction
    injection) and attach the backend epilogue nodes / runtime arguments.
    """

    def __init__(
        self,
        epilogue,
        tile_description,
        element_accumulator,
        elements_per_access,
        element_compute,
        element_output,
    ) -> None:
        # Tiling configuration and data types forwarded to tree nodes.

        self.tile_description = tile_description
        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access
        self.element_compute = element_compute
        self.element_output = element_output
        self.epilogue = epilogue

        # dedent() because __call__ is indented inside its class body.
        self.source = textwrap.dedent(inspect.getsource(epilogue.__call__))
        self.ast_tree = ast.parse(self.source)
        self.epilogue_tree = Tree()

        # print(ast.dump(self.ast_tree, indent=4)) # For debug purposes

        # input arguments: name -> [annotation constant value, ...]
        # (filled by visit_FunctionDef from string annotations such as
        # 'tensor', 'row', 'column', 'scalar')
        self.input_args = {}
        # names returned by the epilogue (roots of output subtrees)
        self.returns = []
        # reduction sources: reduced name -> [direction, arg2 value, target name]
        # (arg2's meaning is not used here; only index 0 and -1 are read later)
        self.reduction_source = {}

        # stack used to keep the parent node id while visiting
        self.stack = []

        # visit the AST
        self.visit(self.ast_tree)

    # visit the name node
    def visit_Name(self, node):
        """Create a leaf/output node for a bare name, or record a return id."""
        # append the return ids into self.returns
        if self.stack[-1] == "return":
            self.returns.append(node.id)
        else:
            # "accum" is the reserved name for the GEMM accumulator
            if node.id == "accum":
                name_node = AccumulatorNode(
                    self.element_accumulator,
                    self.elements_per_access,
                    node,
                )
            else:
                # for input nodes: dispatch on the annotation recorded
                # by visit_FunctionDef
                if node.id in self.input_args.keys():
                    type = self.input_args[node.id][0]
                    if type == "tensor":
                        name_node = TensorInputNode(
                            self.element_accumulator,
                            node,
                        )
                    elif type == "row":
                        name_node = RowBroadcastNode(
                            self.element_accumulator,
                            self.element_compute,
                            node,
                        )
                    elif type == "column":
                        name_node = ColumnBroadcastNode(
                            self.element_accumulator,
                            self.element_compute,
                            node,
                        )
                    elif type == "scalar":
                        name_node = ScalarInputNode(node)
                    else:
                        raise ValueError(type)
                # for output nodes (names not declared as inputs)
                else:
                    name_node = TensorOutputNode(
                        self.element_accumulator,
                        node,
                    )
            self.epilogue_tree.create_node(
                name_node.tag,
                name_node.id,
                data=name_node,
                parent=self.stack[-1],
            )

    def visit_Assign(self, node):
        """Handle an assignment statement.

        Because the function body is visited in reverse (see
        ``visit_FunctionDef``), an assignment target may already exist in
        the tree as a placeholder created by a later consumer; in that case
        the placeholder is replaced (or kept, for returns / reduction
        sources) and the RHS subtree is attached beneath it.
        """
        pre_assign_node = self.epilogue_tree.get_node(node.targets[0].id)
        if pre_assign_node is None:
            # The assign is to a root node
            # Special case: record `reduction_op(...)` calls instead of
            # building tree nodes for them (handled by pass_inject_reduction)
            if isinstance(node.value, ast.Call):
                if isinstance(node.value.func, ast.Name):
                    func_type = node.value.func.id
                elif isinstance(node.value.func, ast.Attribute):
                    func_type = node.value.func.value.id
                else:
                    raise TypeError
                if func_type == "reduction_op":
                    self.reduction_source[node.value.args[0].id] = [
                        node.value.args[1].value,
                        node.value.args[2].value,
                        node.targets[0].id,
                    ]
                    return
            name_node = TensorOutputNode(self.element_accumulator, node)
            self.epilogue_tree.create_node(
                name_node.tag,
                name_node.id,
                data=name_node,
            )
            self.stack.append(name_node.id)
        else:
            if (
                node.targets[0].id in self.returns
                or node.targets[0].id in self.reduction_source.keys()
            ):
                # keep the existing node: it is a real output / reduction root
                self.stack.append(node.targets[0].id)
            else:
                # intermediate value: splice the RHS in place of the
                # placeholder under its parent
                self.stack.append(
                    pre_assign_node.predecessor(self.epilogue_tree.identifier)
                )
                self.epilogue_tree.remove_node(node.targets[0].id)

        # visit the RHS with the chosen parent on top of the stack
        self.visit(node.value)
        self.stack.pop()

    def visit_Call(self, node):
        """Handle a function call: either a `reduction_op` pass-through or a unary op."""
        if isinstance(node.func, ast.Name):
            func_type = node.func.id
        elif isinstance(node.func, ast.Attribute):
            func_type = node.func.value.id
        else:
            raise TypeError
        if func_type == "reduction_op":
            # reduction is handled separately; just visit the reduced value
            self.visit(node.args[0])
        else:
            # collect the extra (non-input) arguments: constants or names
            arg_list = []
            for idx, arg in enumerate(node.args):
                if idx == 0:
                    continue
                if isinstance(arg, ast.Constant):
                    arg_list.append(arg.value)
                elif isinstance(arg, ast.Name):
                    arg_list.append(arg.id)
                else:
                    raise TypeError

            unary_node = UnaryNode(
                self.element_accumulator,
                self.element_compute,
                self.elements_per_access,
                node,
                arg_list,
            )
            self.epilogue_tree.create_node(
                unary_node.tag,
                unary_node.id,
                parent=self.stack[-1],
                data=unary_node,
            )
            self.stack.append(unary_node.id)
            # the first call argument is the value being transformed
            self.visit(node.args[0])
            self.stack.pop()

    def visit_BinOp(self, node):
        """Handle a binary expression: create a BinOpNode and visit both operands."""
        binop = BinOpNode(
            self.element_accumulator,
            self.element_compute,
            self.elements_per_access,
            node,
        )
        self.epilogue_tree.create_node(
            binop.tag,
            binop.id,
            data=binop,
            parent=self.stack[-1],
        )
        self.stack.append(binop.id)
        self.visit(node.left)
        self.visit(node.right)
        self.stack.pop()

    def visit_Return(self, node):
        """Mark the visit context as 'return' so visit_Name records return ids."""
        self.stack.append("return")
        self.visit(node.value)
        self.stack.pop()

    # A function definition (the epilogue's __call__)
    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Record annotated input arguments, then visit the body bottom-up."""
        # visit args: only string-constant annotations are recorded
        for arg in node.args.args:
            if arg.arg == "self":
                continue
            if isinstance(arg.annotation, ast.Constant):
                self.input_args[arg.arg] = [
                    arg.annotation.value,
                ]
        # visit the statements in reverse order so that consumers (returns)
        # are seen before the assignments that produce their inputs
        for idx in range(len(node.body)):
            self.visit(node.body[-1 - idx])

    #
    # Tree optimization pass
    #

    # pass 1: lower Binary to Unary when one operand is a scalar;
    # the scalar becomes a functor argument of the resulting UnaryNode
    def pass_binary_2_unary(self, tree, nid):
        """Recursively lower scalar-operand BinOpNodes into UnaryNodes."""
        node = tree.get_node(nid)
        if isinstance(node.data, BinOpNode):
            lhs_node = tree.get_node(node.successors(tree.identifier)[0])
            left_type = lhs_node.data.type
            rhs_node = tree.get_node(node.successors(tree.identifier)[1])
            right_type = rhs_node.data.type

            if left_type == "scalar" and right_type == "tensor":
                node.data = UnaryNode(
                    self.element_accumulator,
                    self.element_compute,
                    self.elements_per_access,
                    node.data,
                    [
                        lhs_node.data.id,
                    ],
                )
                node.tag = node.data.tag
                tree.remove_node(lhs_node.data.id)
                self.pass_binary_2_unary(tree, rhs_node.data.id)

            elif left_type == "tensor" and right_type == "scalar":
                node.data = UnaryNode(
                    self.element_accumulator,
                    self.element_compute,
                    self.elements_per_access,
                    node.data,
                    [
                        # NOTE(review): the mirror branch above uses
                        # lhs_node.data.id here; confirm rhs_node.id (the
                        # treelib Node, not its payload) resolves as intended.
                        rhs_node.id,
                    ],
                )
                node.tag = node.data.tag
                tree.remove_node(rhs_node.data.id)
                self.pass_binary_2_unary(tree, lhs_node.data.id)

            else:
                self.pass_binary_2_unary(tree, lhs_node.data.id)
                self.pass_binary_2_unary(tree, rhs_node.data.id)
        else:
            for child in node.successors(tree.identifier):
                self.pass_binary_2_unary(tree, child)

    # pass 2: inject reduction nodes recorded in self.reduction_source
    def pass_inject_reduction(self, tree, nid):
        """Replace (or extend) TensorOutputNodes that are reduction sources."""
        node = tree.get_node(nid)
        if isinstance(node.data, TensorOutputNode):
            if node.data.id in self.reduction_source.keys():
                # index 0 = direction ('row' / 'column'); index -1 = target name
                direction = self.reduction_source[node.data.id][0]
                target = self.reduction_source[node.data.id][-1]
                if direction == "row":
                    reduction_node = RowReductionNode(
                        self.element_accumulator,
                        self.element_output,
                        self.element_accumulator,
                        target,
                        self.tile_description.threadblock_shape[1],
                    )
                elif direction == "column":
                    reduction_node = ColumnReductionNode(
                        self.element_accumulator,
                        self.element_output,
                        self.element_accumulator,
                        target,
                        self.tile_description.threadblock_shape[0],
                    )
                else:
                    raise ValueError(direction)
                child_nid = node.successors(tree.identifier)[0]
                # if this output node is injected only for reduction,
                # replace it with the reduction node in place
                if node.data.id not in self.returns:
                    # get reduction config from disc
                    node.data = reduction_node
                    node.tag = reduction_node.tag
                    self.pass_inject_reduction(tree, child_nid)
                # if this output node is also a tensor output, inject reduction as its children
                else:
                    # insert the reduction node between the output and its child
                    tree.create_node(
                        reduction_node.tag,
                        reduction_node.id,
                        data=reduction_node,
                        parent=node.data.id,
                    )
                    tree.move_node(
                        child_nid,
                        reduction_node.id,
                    )
                    child = tree.get_node(child_nid)
                    for grand_child in child.successors(tree.identifier):
                        self.pass_inject_reduction(tree, grand_child)
            else:
                for child in node.successors(tree.identifier):
                    self.pass_inject_reduction(tree, child)
        else:
            for child in node.successors(tree.identifier):
                self.pass_inject_reduction(tree, child)

    def pass_inject_epilogue_op(self, tree, nid):
        """Bottom-up pass: build each node's backend epilogue node from its children's."""
        node = tree.get_node(nid)
        visitors = []
        for child in node.successors(tree.identifier):
            visitors.append(self.pass_inject_epilogue_op(tree, child))

        node.data.get_epilogue_node(visitors)
        return node.data.epilogue_node

    def get_arguments(self, tree, nid, kwargs):
        """Bottom-up pass: build each node's runtime argument struct from its children's."""
        node = tree.get_node(nid)
        visitor_args = []
        for child in node.successors(tree.identifier):
            visitor_args.append(self.get_arguments(tree, child, kwargs))

        node.data.get_argument(visitor_args, kwargs)
        return node.data.argument


class EpilogueVisitTree:
    """User-facing wrapper compiling a Python epilogue into a visitor tree.

    An instance (typically of a subclass providing ``__call__``) is parsed
    by :class:`EpilogueAST` in :meth:`initialize`, which runs the lowering
    passes and exposes a ctypes argument wrapper as ``self.epilogue_type``.
    """

    # C++ template instantiating the generic epilogue visitor for one kernel.
    KernelTemplate = """
${visitor}

using ${operation_name}_EpilogueVisitor = cutlass::epilogue::threadblock::EpilogueVisitorGeneric<${visitor_name}>;
"""

    def __init__(
        self,
        elementwise_functor,
        tile_description,
        element_accumulator,
        elements_per_access,
        element_compute,
        element_output,
    ) -> None:
        # data types and tiling configuration forwarded to EpilogueAST
        self.tile_description = tile_description
        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access
        self.element_compute = element_compute
        self.element_output = element_output
        self.elementwise_functor = elementwise_functor

    def initialize(self):
        """Parse and lower the epilogue, then build the ctypes argument class."""
        function = EpilogueAST(
            self,
            self.tile_description,
            self.element_accumulator,
            self.elements_per_access,
            self.element_compute,
            self.element_output,
        )

        tree = function.epilogue_tree
        self.tree = tree
        function.pass_binary_2_unary(self.tree, self.tree.root)
        function.pass_inject_reduction(self.tree, self.tree.root)
        function.pass_inject_epilogue_op(self.tree, self.tree.root)

        visitor = self.tree.get_node(self.tree.root).data.epilogue_node
        self.visitor = visitor

        class _Argument(ctypes.Structure):
            """ctypes argument wrapper: uploads inputs, tracks return tensors."""

            _fields_ = [
                (
                    "visitor_arg",
                    visitor.argument_type,
                )
            ]

            def __init__(self, **kwargs) -> None:
                # process input args: upload each tensor input to device and
                # record its device pointer as "<name>_ptr"
                _kwargs = {}
                for input_key in function.input_args.keys():
                    if input_key == "accum":
                        continue
                    if function.input_args[input_key][0] == "scalar":
                        continue
                    # tensor input
                    else:
                        setattr(
                            self,
                            "buffer_tensor_" + input_key,
                            NumpyFrontend.argument(
                                kwargs[input_key],
                                False,
                            ),
                        )
                        setattr(
                            self,
                            input_key + "_ptr",
                            int(
                                getattr(
                                    self,
                                    "buffer_tensor_" + input_key,
                                ).ptr
                            ),
                        )
                        _kwargs[input_key + "_ptr"] = getattr(
                            self,
                            input_key + "_ptr",
                        )
                # process the return args: allocate device buffers and keep
                # the host tensors so sync() can copy results back
                for ret in function.returns:
                    setattr(
                        self,
                        "buffer_tensor_" + ret,
                        NumpyFrontend.argument(kwargs[ret], True),
                    )
                    setattr(
                        self,
                        ret + "_ptr",
                        int(
                            getattr(
                                self,
                                "buffer_tensor_" + ret,
                            ).ptr
                        ),
                    )
                    _kwargs[ret + "_ptr"] = getattr(self, ret + "_ptr")
                    setattr(
                        self,
                        "host_tensor_" + ret,
                        kwargs[ret],
                    )

                _kwargs.update(kwargs)
                function.get_arguments(tree, tree.root, _kwargs)
                self.visitor_arg = tree.get_node(tree.root).data.argument

            def sync(self, stream_sync=True):
                """Synchronize the device and copy every return tensor to host.

                Raises:
                    RuntimeError: on any CUDA error from synchronize or copy.
                """
                if stream_sync:
                    (err,) = cudart.cudaDeviceSynchronize()
                    # cudaDeviceSynchronize returns a runtime cudaError_t, so
                    # compare against cudaSuccess (not the driver CUresult).
                    if err != cudart.cudaError_t.cudaSuccess:
                        raise RuntimeError("CUDA Error %s" % str(err))

                for ret in function.returns:
                    (err,) = cuda.cuMemcpyDtoH(
                        getattr(
                            self,
                            "host_tensor_" + ret,
                        ),
                        cuda.CUdeviceptr(getattr(self, ret + "_ptr")),
                        getattr(
                            self,
                            "host_tensor_" + ret,
                        ).size
                        * getattr(
                            self,
                            "host_tensor_" + ret,
                        ).itemsize,
                    )
                    # Check every copy individually; previously only the last
                    # iteration's status was inspected (outside the loop).
                    if err != cuda.CUresult.CUDA_SUCCESS:
                        raise RuntimeError("CUDA Error %s" % str(err))

        self.epilogue_type = _Argument

    def emit(self, operation):
        """Render the C++ epilogue visitor declaration for `operation`."""
        values = {
            "visitor": self.visitor.emit(operation),
            "operation_name": operation.procedural_name(),
            "visitor_name": self.visitor.instance_name,
        }
        return SubstituteTemplate(self.KernelTemplate, values)
